%% Run this script to derive all EWM results in the manuscript
% Key parameters characterizing the treatment rules are saved and later
% loaded by analyze_4_summarize_ewm_rules.m to produce tables and figures.

% Start from a clean workspace and command window, then move two folder
% levels up from the current folder; every relative path below is resolved
% from there.
% NOTE(review): presumably the script is expected to be launched from its
% own folder so that two levels up is the project root -- confirm.
clear; clc;
cd(fullfile('..','..'))

% Paths are assembled with fullfile so the platform's file separator is
% used. On Windows the resulting strings are identical to hard-coded
% '.\...' literals; on other platforms they become './...'.
savedir=fullfile('.','intermediate_data','EWM_results'); % handed to the EWM routines further down
savedir_tables=fullfile('.','output','tables');
savedir_figures=fullfile('.','output','figures');

% Make the data folder, the analysis code, and the saved EWM results
% reachable by name.
addpath(fullfile('.','intermediate_data'))
addpath(fullfile('.','code','analyze'))
addpath(savedir)

%% Main specification
individual_ps=0; % =0 if using treatment probability for each of the three waves
winter_prevuse=0; % =0 if pre-treatment average calculated using specified number
% of months prior to treatment
pre_avg_n=12; % 12 month pre-treatment period
post_avg_n=12; % 12 month post-treatment period
wave=0; % =0 if pooled; otherwise =3, 6, or 7

% Define tex output name
% The pre/post switches carry an explicit otherwise branch so that an
% unsupported window length stops here with a clear message instead of
% surfacing later as an undefined-variable error when fbase is built.
if (winter_prevuse)
    s_pre='pre_avg_winter_recent';
else
    switch pre_avg_n
        case {12, 21, 24}
            s_pre=sprintf('pre_avg_%d',pre_avg_n);
        otherwise
            error('ewm:unsupportedPreAvgN', ...
                'Unsupported pre_avg_n = %g (expected 12, 21, or 24).',pre_avg_n);
    end
end

switch post_avg_n
    case {12, 21, 24}
        s_post=sprintf('post_avg_%d',post_avg_n);
    case 0
        s_post='all_post';
    otherwise
        error('ewm:unsupportedPostAvgN', ...
            'Unsupported post_avg_n = %g (expected 0, 12, 21, or 24).',post_avg_n);
end

% propensity-score tag: 'ips' only when individual_ps is exactly 1
if individual_ps==1
    s_ps='ips';
else
    s_ps='wps';
end

% fbase specifies output table prefix.
% The double underscore before s_post is kept on purpose: the previous
% call, sprintf('%s_%s',s_ps,s_pre,'_',s_post), passed four arguments to a
% two-operator format, and sprintf's format recycling produced exactly
% <s_ps>_<s_pre>__<s_post>. Spelling the format out keeps the prefix
% unchanged while making the result visible in the code.
fbase=sprintf('%s_%s__%s',s_ps,s_pre,s_post);

% Label for the selected wave; wave_name is left undefined for any value
% other than 0, 3, 6, or 7.
switch wave
    case 0
        wave_name='pooled';
    case {3, 6, 7}
        wave_name=sprintf('wave%d',wave);
end
%% Import sample
% Load the panel file and hand it, together with the specification flags
% set above, to generate_cross_sample, which returns sample_cross.
sample_panel = readtable('elec_wave367_winterpre.csv');
sample_cross = generate_cross_sample( ...
    sample_panel, wave, winter_prevuse, pre_avg_n, individual_ps, post_avg_n);

% panel data not used for ewm analysis
clear sample_panel

%% Covariates defined in terms of observable characteristics
% fullsample=0 selects the estimation sample used in the main analysis
% (non-missing covariates and non-missing pre-treatment consumption);
% "full sample" refers to the sample with non-missing pre-treatment
% consumption only.
rules={'quadrant', 'cubic'};
bsimax_rules = {1000, 300}; % 1,000 bootstraps for quadrant rules and 300 bootstraps for cubic rules
fullsample=0;

% Outer loop: rule type. Inner loop: the three (cost, SMC) specifications.
for iRule=1:2
    rule = rules{iRule};
    bsimax = bsimax_rules{iRule}; % bootstrap count paired with this rule

    % (cost, SMC) pairs, labelled PMC, kwh, SMC in s_cost below
    cost={true false true};
    SMC={0 0 1};

    % Containers are re-created for every rule, so once the outer loop
    % finishes they hold the values from the last rule only.
    EWM_savings_demographics=cell(3,1);
    Scaled_ATT=cell(3,1);
    ATE=cell(3,1);

    for iSpec=1:3
        % label of the current cost specification
        if ~cost{iSpec}
            s_cost='kwh';
        elseif SMC{iSpec}
            s_cost='SMC';
        else
            s_cost='PMC';
        end

        % Benchmark savings in terms of ATE and ATT
        [Scaled_ATT{iSpec},ATE{iSpec}]=calculate_benchmark_savings( ...
            sample_cross,wave,cost{iSpec},SMC{iSpec},bsimax);

        % Estimate the EWM rules
        EWM_savings_demographics{iSpec}=generate_tables_covariate( ...
            sample_cross,winter_prevuse,wave,cost{iSpec},bsimax,SMC{iSpec}, ...
            savedir,fbase,rule,fullsample);
    end
end

%% Covariates defined in terms of functions of PMC, kWh, SMC baseline consumption
fullsample=0;

% Outer loop: rule type (rules and bsimax_rules come from the previous
% section). Inner loop: the three (cost, SMC) specifications.
for iRule=1:2
    rule = rules{iRule};
    bsimax = bsimax_rules{iRule}; % bootstrap count paired with this rule

    % (cost, SMC) pairs, labelled PMC, kwh, SMC in s_cost below
    cost={true false true};
    SMC={0 0 1};

    % Containers are re-created for every rule, so once the outer loop
    % finishes they hold the values from the last rule only.
    EWM_savings_baseline=cell(3,1);
    Scaled_ATT=cell(3,1);
    ATE=cell(3,1);

    for iSpec=1:3
        % label of the current cost specification
        if ~cost{iSpec}
            s_cost='kwh';
        elseif SMC{iSpec}
            s_cost='SMC';
        else
            s_cost='PMC';
        end

        % Benchmark savings in terms of ATE and ATT
        [Scaled_ATT{iSpec},ATE{iSpec}]=calculate_benchmark_savings( ...
            sample_cross,wave,cost{iSpec},SMC{iSpec},bsimax);

        % Estimate the EWM rules
        EWM_savings_baseline{iSpec}=generate_tables_baseline( ...
            sample_cross,winter_prevuse,wave,cost{iSpec},bsimax,SMC{iSpec}, ...
            savedir,fbase,rule,fullsample);
    end
end


 %% Generate point estimates and CIs for delta ewm v. rct
 % note: this set of CIs are constructed by applying the estimated ewm
 % rules on each bootstrap sample, not the CIs based on K&T.
bsimax=1000;
cost={true false true};
SMC={0 0 1};
covariates={'income','size','vintage','min','max','std'};

% rule type x cost specification x covariate
for r=1:2
    rule = rules{r};
    for i=1:3
        % label of the current cost specification
        if ~cost{i}
            s_cost='kwh';
        elseif SMC{i}
            s_cost='SMC';
        else
            s_cost='PMC';
        end

        % tcost is 0.765 whenever the cost flag is set and 0 otherwise
        tcost = 0;
        if cost{i}
            tcost = 0.765;
        end

        for j=1:numel(covariates)
            generate_delta_CI_bootstrap(bsimax,sample_cross,rule,wave,tcost, ...
                covariates{j},winter_prevuse,SMC{i},savedir,fbase)
        end
    end
end
%% Cross waves analysis
% cost and SMC are defined here explicitly, with the same values the
% preceding section leaves in the workspace, so that this section does not
% depend on whichever section happened to run before it (and does not fail
% with an undefined variable when run on its own).
cost={true false true};
SMC={0 0 1};
bsimax=0;

for i=1:numel(cost)
    % The same six output variables are assigned on every pass, so after
    % the loop they hold the results of the last (cost, SMC) pair only.
    [savings_cv_quadrant_wave3,savings_cv_quadrant_wave6,savings_cv_quadrant_wave7,...
        savings_cv_cubic_wave3,savings_cv_cubic_wave6,savings_cv_cubic_wave7] = ...
        generate_results_cross_waves(sample_cross,winter_prevuse,cost{i},bsimax,SMC{i},savedir,fbase);
end

%% Out of sample performance
cost={false true true};
SMC={0 0 1};
covariates={'income','size','vintage'};
pm=100; % 100 permutations
reweight=0; % no reweight between sample wave and target wave
savedir_cv_pooled='.\intermediate_data\EWM_results\cv_pooled';
addpath(savedir_cv_pooled)

seed = 1:pm; % set seed for each permutation random draw
split_choice=0.5;
% generate testing and training samples
[idx_t,idx_h] = generate_cv_samples_permutation( ...
    sample_cross,pm,seed,split_choice,savedir_cv_pooled,fbase);

% estimate ewm savings ------------------------------
% one (cost specification x covariate) cell grid per rule type
ewm_cv_pooled_quadrant=cell(3,3);
ewm_cv_pooled_cubic=cell(3,3);
ewm_cv_pooled_univariate=cell(3,3);

rules={'univariate','quadrant', 'cubic'};
for r=1:numel(rules)
    rule=rules{r};
    % NOTE(review): the covariate index starts at 2, so covariates{1}
    % ('income') is never evaluated and column 1 of each result grid stays
    % empty -- confirm this is intended.
    for j=2:3
        for i=1:3
            % label of the current cost specification
            if ~cost{i}
                s_cost='kwh';
            elseif SMC{i}
                s_cost='SMC';
            else
                s_cost='PMC';
            end

            % tcost is 0.765 whenever the cost flag is set and 0 otherwise
            tcost = 0;
            if cost{i}
                tcost = 0.765;
            end

            % The call is the same for every rule type; only the grid that
            % receives the result differs.
            cv_result = generate_cv_results_permutation(sample_cross,pm,split_choice, ...
                idx_t,idx_h,rule,reweight,tcost,covariates{j},winter_prevuse,SMC{i}, ...
                savedir_cv_pooled,fbase);

            switch rule
                case 'quadrant'
                    ewm_cv_pooled_quadrant{i,j} = cv_result;
                case 'cubic'
                    ewm_cv_pooled_cubic{i,j} = cv_result;
                case 'univariate'
                    ewm_cv_pooled_univariate{i,j} = cv_result;
            end
        end
    end
end

%% Budget constraint Analysis
wave=0;
bsimax=0;
cost={false true true};
SMC={0 0 1};
covariates={'income','size','vintage'};
share = mean(sample_cross.opower); % sample mean of the opower column

rule_set=["univariate","quadrant"];
fix_range=0;

nCost=numel(cost);
nCov=numel(covariates);

for r=1:numel(rule_set)
    rule=char(rule_set(r));

    % fresh (cost specification x covariate) result grids for this rule
    switch rule
        case 'quadrant'
            ewm_capshare_pooled_quadrant=cell(nCost,nCov);
            ewm_fixedshare_pooled_quadrant=cell(nCost,nCov);
        case 'cubic' % not reachable with the rule_set defined above
            ewm_capshare_pooled_cubic=cell(nCost,nCov);
            ewm_fixedshare_pooled_cubic=cell(nCost,nCov);
        case 'univariate'
            ewm_capshare_pooled_univariate=cell(nCost,nCov);
            ewm_fixedshare_pooled_univariate=cell(nCost,nCov);
    end

    for j=1:nCov
        for i=1:nCost
            % label of the current cost specification
            if ~cost{i}
                s_cost='kwh';
            elseif SMC{i}
                s_cost='SMC';
            else
                s_cost='PMC';
            end

            % tcost is 0.765 whenever the cost flag is set and 0 otherwise
            tcost = 0;
            if cost{i}
                tcost = 0.765;
            end

            % capped-share result first, then fixed-share result
            switch rule
                case 'quadrant'
                    ewm_capshare_pooled_quadrant{i,j} = generate_results_capshare_quadrant_rule( ...
                        bsimax,sample_cross,share,wave,tcost,covariates{j},winter_prevuse,SMC{i},savedir,fbase);
                    ewm_fixedshare_pooled_quadrant{i,j} = generate_results_fixedshare_quadrant_rule( ...
                        bsimax,sample_cross,share,fix_range,wave,tcost,covariates{j},winter_prevuse,SMC{i},savedir,fbase);
                case 'univariate'
                    ewm_capshare_pooled_univariate{i,j} = generate_results_capshare_onedimension( ...
                        bsimax,sample_cross,share,wave,tcost,covariates{j},winter_prevuse,SMC{i},savedir,fbase);
                    ewm_fixedshare_pooled_univariate{i,j} = generate_results_fixedshare_onedimension( ...
                        bsimax,sample_cross,share,fix_range,wave,tcost,covariates{j},winter_prevuse,SMC{i},savedir,fbase);
            end
        end
    end
end

 
 %% Full sample analysis
% full sample refers to the sample with non-missing pre-treatment
% consumption only; the main analysis uses the estimation sample with
% non-missing covariates and non-missing pre-treatment consumption.
rules={'quadrant', 'cubic'};
fullsample=1;
bsimax=0;

% Load the full-sample panel file and pass it, with the specification flags,
% to generate_cross_sample_baseline_fullsample, which returns
% sample_cross_full.
sample_panel_full=readtable('elec_wave367_winterpre_nomissing_baseline.csv');
sample_cross_full = generate_cross_sample_baseline_fullsample( ...
    sample_panel_full,wave,winter_prevuse,pre_avg_n,individual_ps,post_avg_n);

% Outer loop: rule type. Inner loop: the three (cost, SMC) specifications.
for iRule=1:2
    rule = rules{iRule};

    % (cost, SMC) pairs, labelled PMC, kwh, SMC in s_cost below
    cost={true false true};
    SMC={0 0 1};

    % Containers are re-created for every rule, so once the outer loop
    % finishes they hold the values from the last rule only.
    EWM_savings_baseline=cell(3,1);
    Scaled_ATT=cell(3,1);
    ATE=cell(3,1);

    for iSpec=1:3
        % label of the current cost specification
        if ~cost{iSpec}
            s_cost='kwh';
        elseif SMC{iSpec}
            s_cost='SMC';
        else
            s_cost='PMC';
        end

        % tcost is 0.765 whenever the cost flag is set and 0 otherwise;
        % neither call below takes it as an argument
        tcost = 0;
        if cost{iSpec}
            tcost = 0.765;
        end

        % Benchmark savings in terms of ATE and ATT
        [Scaled_ATT{iSpec},ATE{iSpec}]=calculate_benchmark_savings( ...
            sample_cross_full,wave,cost{iSpec},SMC{iSpec},bsimax);

        % Estimate the EWM rules
        EWM_savings_baseline{iSpec}=generate_tables_baseline( ...
            sample_cross_full,winter_prevuse,wave,cost{iSpec},bsimax,SMC{iSpec}, ...
            savedir,fbase,rule,fullsample);
    end
end